package org.neuroph.core.learning;

import com.google.firebase.remoteconfig.FirebaseRemoteConfig;
import java.io.Serializable;
import java.util.Iterator;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.error.ErrorFunction;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.core.learning.stop.MaxErrorStop;

/* loaded from: classes2.dex */
/**
 * Base class for learning rules that train a network against desired outputs.
 *
 * <p>For every row of the training set the network output is computed, the
 * difference to the desired output is accumulated in an {@link ErrorFunction},
 * and the subclass-specific {@link #updateNetworkWeights(double[])} is invoked.
 * In batch mode the accumulated weight changes are applied once per epoch
 * instead (see {@link #doBatchWeightsUpdate()}).
 *
 * <p>Several members are {@code public} only because the decompiler widened
 * them from {@code protected}; the visibility is kept so that existing callers
 * keep compiling.
 */
public abstract class SupervisedLearning extends IterativeLearning implements Serializable {
    private static final long serialVersionUID = 3L;

    /** Error function fed with every pattern error during an epoch. */
    private ErrorFunction errorFunction;

    /**
     * {@code true} once a non-null error function was installed through
     * {@link #setErrorFunction(ErrorFunction)}; tells {@link #onStart()} not
     * to replace it with the default. Deliberately not transient so that the
     * choice survives serialization together with the function itself.
     */
    private boolean errorFunctionSetByCaller = false;

    /** Number of consecutive epochs whose error change was within {@link #minErrorChange}. */
    private transient int minErrorChangeIterationsCount;

    /** Total network error of the epoch before the current one. */
    protected transient double previousEpochError;

    /** Total network error as reported by the error function after the last epoch. */
    protected transient double totalNetworkError;

    /** Running sum maintained by {@link #addToSquaredErrorSum(double[])}. */
    protected transient double totalSquaredErrorSum;

    /** Size of the training set captured in {@link #onStart()}. */
    private int trainingSetSize;

    /** Target error; handed to the stop condition registered in {@link #onStart()}. */
    protected double maxError = 0.01d;

    /**
     * Threshold for the per-epoch error change. With the default of positive
     * infinity every (non-NaN) change counts as "within threshold".
     */
    private double minErrorChange = Double.POSITIVE_INFINITY;

    private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;

    /** When {@code true}, weight changes are applied once per epoch in {@link #afterEpoch()}. */
    private boolean batchMode = false;

    /**
     * Adds half of the sum of squares of the given values to
     * {@link #totalSquaredErrorSum}.
     *
     * @param outputError values to square and accumulate
     */
    public void addToSquaredErrorSum(double[] outputError) {
        double halfSquaredSum = 0.0d;
        for (double error : outputError) {
            halfSquaredSum += error * error * 0.5d;
        }
        this.totalSquaredErrorSum += halfSquaredSum;
    }

    /**
     * Updates the "error barely changed" counter and, in batch mode, applies
     * the weight changes accumulated during the epoch.
     */
    @Override // org.neuroph.core.learning.IterativeLearning
    protected void afterEpoch() {
        double errorChange = Math.abs(this.previousEpochError - this.totalNetworkError);
        if (errorChange <= this.minErrorChange) {
            this.minErrorChangeIterationsCount++;
        } else {
            this.minErrorChangeIterationsCount = 0;
        }
        if (this.batchMode) {
            doBatchWeightsUpdate();
        }
    }

    /**
     * Remembers the error of the epoch that just finished and clears the
     * per-epoch accumulators, including the error function.
     */
    @Override // org.neuroph.core.learning.IterativeLearning
    protected void beforeEpoch() {
        this.previousEpochError = this.totalNetworkError;
        this.totalNetworkError = 0.0d;
        this.totalSquaredErrorSum = 0.0d;
        this.errorFunction.reset();
    }

    /**
     * Computes the element-wise difference {@code desiredOutput[i] - output[i]}.
     *
     * <p>The result has the length of {@code desiredOutput}, while the loop
     * runs over the length of {@code output}: if {@code output} is shorter,
     * the trailing result elements stay {@code 0.0}; if it is longer, an
     * {@link ArrayIndexOutOfBoundsException} is thrown.
     *
     * @param desiredOutput expected output values
     * @param output        actual output values
     * @return a new array holding the differences
     */
    public double[] calculateOutputError(double[] desiredOutput, double[] output) {
        double[] outputError = new double[desiredOutput.length];
        for (int i = 0; i < output.length; i++) {
            outputError[i] = desiredOutput[i] - output[i];
        }
        return outputError;
    }

    /**
     * Applies the accumulated {@code weightChange} of every input connection
     * and resets it to zero. Layers are visited from the last one down to
     * index 1; the layer at index 0 is skipped.
     */
    protected void doBatchWeightsUpdate() {
        Layer[] layers = this.neuralNetwork.getLayers();
        for (int layerIdx = this.neuralNetwork.getLayersCount() - 1; layerIdx > 0; layerIdx--) {
            for (Neuron neuron : layers[layerIdx].getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    Weight weight = connection.getWeight();
                    weight.value += weight.weightChange;
                    weight.weightChange = 0.0d;
                }
            }
        }
    }

    /**
     * Runs one pass over the data set, learning each row until the rows are
     * exhausted or learning has been stopped, then stores the total error
     * reported by the error function.
     *
     * @param dataSet rows to learn in this epoch
     */
    @Override // org.neuroph.core.learning.IterativeLearning
    public void doLearningEpoch(DataSet dataSet) {
        Iterator<DataSetRow> rows = dataSet.iterator();
        while (rows.hasNext() && !isStopped()) {
            learnPattern(rows.next());
        }
        this.totalNetworkError = this.errorFunction.getTotalError();
    }

    /** @return the error function currently in use (may be {@code null} before training starts) */
    public ErrorFunction getErrorFunction() {
        return this.errorFunction;
    }

    /** @return the configured maximum error */
    public double getMaxError() {
        return this.maxError;
    }

    /** @return the configured minimum error change threshold */
    public double getMinErrorChange() {
        return this.minErrorChange;
    }

    /** @return number of consecutive epochs whose error change was within the threshold */
    public int getMinErrorChangeIterationsCount() {
        return this.minErrorChangeIterationsCount;
    }

    /** @return the configured limit for the min-error-change counter */
    public int getMinErrorChangeIterationsLimit() {
        return this.minErrorChangeIterationsLimit;
    }

    /** @return total network error of the previous epoch */
    public double getPreviousEpochError() {
        return this.previousEpochError;
    }

    /** @return total network error recorded after the last completed epoch */
    public synchronized double getTotalNetworkError() {
        return this.totalNetworkError;
    }

    /** @return {@code true} if weights are updated once per epoch */
    public boolean isInBatchMode() {
        return this.batchMode;
    }

    /**
     * Sets the maximum error and starts learning.
     *
     * @param dataSet  training data
     * @param maxError value stored as the maximum error
     */
    public void learn(DataSet dataSet, double maxError) {
        this.maxError = maxError;
        learn(dataSet);
    }

    /**
     * Sets the maximum error and the iteration limit, then starts learning.
     *
     * @param dataSet       training data
     * @param maxError      value stored as the maximum error
     * @param maxIterations value passed to {@code setMaxIterations}
     */
    public void learn(DataSet dataSet, double maxError, int maxIterations) {
        this.maxError = maxError;
        setMaxIterations(maxIterations);
        learn(dataSet);
    }

    /**
     * Learns a single row: feeds its input to the network, computes the output
     * error against the desired output, records it in the error function and
     * lets the subclass update the weights.
     *
     * @param dataSetRow the row to learn
     */
    protected void learnPattern(DataSetRow dataSetRow) {
        this.neuralNetwork.setInput(dataSetRow.getInput());
        this.neuralNetwork.calculate();
        double[] outputError =
                calculateOutputError(dataSetRow.getDesiredOutput(), this.neuralNetwork.getOutput());
        this.errorFunction.addOutputError(outputError);
        updateNetworkWeights(outputError);
    }

    /**
     * Resets the error bookkeeping, captures the training set size, makes sure
     * an error function is available and registers a {@link MaxErrorStop}.
     *
     * <p>A {@link MeanSquaredError} sized to the current training set is
     * created unless the caller installed an error function through
     * {@link #setErrorFunction(ErrorFunction)}; previously the caller's
     * function was always discarded here.
     */
    @Override // org.neuroph.core.learning.IterativeLearning, org.neuroph.core.learning.LearningRule
    public void onStart() {
        super.onStart();
        this.minErrorChangeIterationsCount = 0;
        this.totalNetworkError = 0.0d;
        this.previousEpochError = 0.0d;
        int size = getTrainingSet().size();
        this.trainingSetSize = size;
        if (!this.errorFunctionSetByCaller || this.errorFunction == null) {
            // Default function is rebuilt on every start so it always matches
            // the size of the training set being learned.
            this.errorFunction = new MeanSquaredError(size);
        }
        this.stopConditions.add(new MaxErrorStop(this));
    }

    /** @param batchMode {@code true} to apply weight changes once per epoch */
    public void setBatchMode(boolean batchMode) {
        this.batchMode = batchMode;
    }

    /**
     * Installs the error function to use for training. Passing {@code null}
     * reverts to the default that {@link #onStart()} creates.
     *
     * @param errorFunction the error function, or {@code null} for the default
     */
    public void setErrorFunction(ErrorFunction errorFunction) {
        this.errorFunction = errorFunction;
        this.errorFunctionSetByCaller = errorFunction != null;
    }

    /** @param maxError value stored as the maximum error */
    public void setMaxError(double maxError) {
        this.maxError = maxError;
    }

    /** @param minErrorChange threshold for the per-epoch error change */
    public void setMinErrorChange(double minErrorChange) {
        this.minErrorChange = minErrorChange;
    }

    /** @param minErrorChangeIterationsLimit limit for the min-error-change counter */
    public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) {
        this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
    }

    /**
     * Adjusts the network weights for one pattern.
     *
     * @param outputError difference between desired and actual output, as
     *                    produced by {@link #calculateOutputError(double[], double[])}
     */
    protected abstract void updateNetworkWeights(double[] outputError);
}
